{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:32.141518Z",
     "start_time": "2020-03-21T08:00:31.552836Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import warnings\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "from sklearn import preprocessing, metrics\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold\n",
    "import gc\n",
    "import lightgbm as lgb\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from joblib import Parallel, delayed\n",
    "from sklearn.utils import shuffle\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "pd.set_option('display.max_columns', None)\n",
    "pd.set_option('display.max_rows', None)\n",
    "\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:32.145664Z",
     "start_time": "2020-03-21T08:00:32.143183Z"
    }
   },
   "outputs": [],
   "source": [
    "seed = 2020"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:32.247856Z",
     "start_time": "2020-03-21T08:00:32.147000Z"
    }
   },
   "outputs": [],
   "source": [
    "df_feature = pd.read_pickle('./temp/part1_feature.plk')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:32.250630Z",
     "start_time": "2020-03-21T08:00:32.248937Z"
    }
   },
   "outputs": [],
   "source": [
    "# sub_columns = ['courier_id', 'wave_index', 'tracking_id',\n",
    "#                'courier_wave_start_lng', 'courier_wave_start_lat', 'action_type', 'expect_time']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:32.734039Z",
     "start_time": "2020-03-21T08:00:32.251665Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "action_type\n",
      "group\n",
      "last_action_type\n",
      "weather_grade\n",
      "aoi_id\n",
      "shop_id\n"
     ]
    }
   ],
   "source": [
    "for f in df_feature.select_dtypes('object'):\n",
    "    if f not in ['date', 'type']:\n",
    "        print(f)\n",
    "        lbl = LabelEncoder()\n",
    "        df_feature[f] = lbl.fit_transform(df_feature[f].astype(str))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:32.753701Z",
     "start_time": "2020-03-21T08:00:32.735262Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>courier_id</th>\n",
       "      <th>wave_index</th>\n",
       "      <th>tracking_id</th>\n",
       "      <th>courier_wave_start_lng</th>\n",
       "      <th>courier_wave_start_lat</th>\n",
       "      <th>action_type</th>\n",
       "      <th>expect_time</th>\n",
       "      <th>date</th>\n",
       "      <th>type</th>\n",
       "      <th>target</th>\n",
       "      <th>group</th>\n",
       "      <th>id</th>\n",
       "      <th>current_time</th>\n",
       "      <th>last_tracking_id</th>\n",
       "      <th>last_action_type</th>\n",
       "      <th>source_lng</th>\n",
       "      <th>source_lat</th>\n",
       "      <th>target_lng</th>\n",
       "      <th>target_lat</th>\n",
       "      <th>grid_distance</th>\n",
       "      <th>weather_grade</th>\n",
       "      <th>aoi_id</th>\n",
       "      <th>shop_id</th>\n",
       "      <th>promise_deliver_time</th>\n",
       "      <th>estimate_pick_time</th>\n",
       "      <th>level</th>\n",
       "      <th>speed</th>\n",
       "      <th>max_load</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>10007871</td>\n",
       "      <td>0</td>\n",
       "      <td>2100074550065333539</td>\n",
       "      <td>121.630997</td>\n",
       "      <td>39.142343</td>\n",
       "      <td>0</td>\n",
       "      <td>1580528963</td>\n",
       "      <td>20200201</td>\n",
       "      <td>train</td>\n",
       "      <td>1.0</td>\n",
       "      <td>9</td>\n",
       "      <td>0</td>\n",
       "      <td>1580528622</td>\n",
       "      <td>2100074550065333539</td>\n",
       "      <td>1</td>\n",
       "      <td>121.631219</td>\n",
       "      <td>39.141811</td>\n",
       "      <td>121.632084</td>\n",
       "      <td>39.146201</td>\n",
       "      <td>707.0</td>\n",
       "      <td>2</td>\n",
       "      <td>19922</td>\n",
       "      <td>5653</td>\n",
       "      <td>1580530276</td>\n",
       "      <td>1580529019</td>\n",
       "      <td>3</td>\n",
       "      <td>4.751832</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>10007871</td>\n",
       "      <td>0</td>\n",
       "      <td>2100074550779577850</td>\n",
       "      <td>121.630997</td>\n",
       "      <td>39.142343</td>\n",
       "      <td>1</td>\n",
       "      <td>1580529129</td>\n",
       "      <td>20200201</td>\n",
       "      <td>train</td>\n",
       "      <td>0.0</td>\n",
       "      <td>9</td>\n",
       "      <td>1</td>\n",
       "      <td>1580528622</td>\n",
       "      <td>2100074550065333539</td>\n",
       "      <td>1</td>\n",
       "      <td>121.631219</td>\n",
       "      <td>39.141811</td>\n",
       "      <td>121.631574</td>\n",
       "      <td>39.142231</td>\n",
       "      <td>152.0</td>\n",
       "      <td>2</td>\n",
       "      <td>12666</td>\n",
       "      <td>6037</td>\n",
       "      <td>1580530236</td>\n",
       "      <td>1580529399</td>\n",
       "      <td>3</td>\n",
       "      <td>4.751832</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>10007871</td>\n",
       "      <td>0</td>\n",
       "      <td>2100074550779577850</td>\n",
       "      <td>121.630997</td>\n",
       "      <td>39.142343</td>\n",
       "      <td>0</td>\n",
       "      <td>1580529444</td>\n",
       "      <td>20200201</td>\n",
       "      <td>train</td>\n",
       "      <td>0.0</td>\n",
       "      <td>9</td>\n",
       "      <td>2</td>\n",
       "      <td>1580528622</td>\n",
       "      <td>2100074550065333539</td>\n",
       "      <td>1</td>\n",
       "      <td>121.631219</td>\n",
       "      <td>39.141811</td>\n",
       "      <td>121.635154</td>\n",
       "      <td>39.143561</td>\n",
       "      <td>671.0</td>\n",
       "      <td>2</td>\n",
       "      <td>12666</td>\n",
       "      <td>6037</td>\n",
       "      <td>1580530236</td>\n",
       "      <td>1580529399</td>\n",
       "      <td>3</td>\n",
       "      <td>4.751832</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>10007871</td>\n",
       "      <td>1</td>\n",
       "      <td>2100074555638285402</td>\n",
       "      <td>121.631208</td>\n",
       "      <td>39.142519</td>\n",
       "      <td>1</td>\n",
       "      <td>1580532225</td>\n",
       "      <td>20200201</td>\n",
       "      <td>train</td>\n",
       "      <td>1.0</td>\n",
       "      <td>10</td>\n",
       "      <td>3</td>\n",
       "      <td>1580532113</td>\n",
       "      <td>2100074554932692192</td>\n",
       "      <td>0</td>\n",
       "      <td>121.636904</td>\n",
       "      <td>39.142721</td>\n",
       "      <td>121.636701</td>\n",
       "      <td>39.141801</td>\n",
       "      <td>160.0</td>\n",
       "      <td>2</td>\n",
       "      <td>14953</td>\n",
       "      <td>1113</td>\n",
       "      <td>1580533463</td>\n",
       "      <td>1580532384</td>\n",
       "      <td>3</td>\n",
       "      <td>4.751832</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>10007871</td>\n",
       "      <td>1</td>\n",
       "      <td>2100074554118800474</td>\n",
       "      <td>121.631208</td>\n",
       "      <td>39.142519</td>\n",
       "      <td>1</td>\n",
       "      <td>1580532227</td>\n",
       "      <td>20200201</td>\n",
       "      <td>train</td>\n",
       "      <td>0.0</td>\n",
       "      <td>10</td>\n",
       "      <td>4</td>\n",
       "      <td>1580532113</td>\n",
       "      <td>2100074554932692192</td>\n",
       "      <td>0</td>\n",
       "      <td>121.636904</td>\n",
       "      <td>39.142721</td>\n",
       "      <td>121.636701</td>\n",
       "      <td>39.141801</td>\n",
       "      <td>160.0</td>\n",
       "      <td>2</td>\n",
       "      <td>1404</td>\n",
       "      <td>1113</td>\n",
       "      <td>1580533598</td>\n",
       "      <td>1580532339</td>\n",
       "      <td>3</td>\n",
       "      <td>4.751832</td>\n",
       "      <td>11</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   courier_id  wave_index          tracking_id  courier_wave_start_lng  \\\n",
       "0    10007871           0  2100074550065333539              121.630997   \n",
       "1    10007871           0  2100074550779577850              121.630997   \n",
       "2    10007871           0  2100074550779577850              121.630997   \n",
       "3    10007871           1  2100074555638285402              121.631208   \n",
       "4    10007871           1  2100074554118800474              121.631208   \n",
       "\n",
       "   courier_wave_start_lat  action_type  expect_time      date   type  target  \\\n",
       "0               39.142343            0   1580528963  20200201  train     1.0   \n",
       "1               39.142343            1   1580529129  20200201  train     0.0   \n",
       "2               39.142343            0   1580529444  20200201  train     0.0   \n",
       "3               39.142519            1   1580532225  20200201  train     1.0   \n",
       "4               39.142519            1   1580532227  20200201  train     0.0   \n",
       "\n",
       "   group  id  current_time     last_tracking_id  last_action_type  source_lng  \\\n",
       "0      9   0    1580528622  2100074550065333539                 1  121.631219   \n",
       "1      9   1    1580528622  2100074550065333539                 1  121.631219   \n",
       "2      9   2    1580528622  2100074550065333539                 1  121.631219   \n",
       "3     10   3    1580532113  2100074554932692192                 0  121.636904   \n",
       "4     10   4    1580532113  2100074554932692192                 0  121.636904   \n",
       "\n",
       "   source_lat  target_lng  target_lat  grid_distance  weather_grade  aoi_id  \\\n",
       "0   39.141811  121.632084   39.146201          707.0              2   19922   \n",
       "1   39.141811  121.631574   39.142231          152.0              2   12666   \n",
       "2   39.141811  121.635154   39.143561          671.0              2   12666   \n",
       "3   39.142721  121.636701   39.141801          160.0              2   14953   \n",
       "4   39.142721  121.636701   39.141801          160.0              2    1404   \n",
       "\n",
       "   shop_id  promise_deliver_time  estimate_pick_time  level     speed  \\\n",
       "0     5653            1580530276          1580529019      3  4.751832   \n",
       "1     6037            1580530236          1580529399      3  4.751832   \n",
       "2     6037            1580530236          1580529399      3  4.751832   \n",
       "3     1113            1580533463          1580532384      3  4.751832   \n",
       "4     1113            1580533598          1580532339      3  4.751832   \n",
       "\n",
       "   max_load  \n",
       "0        11  \n",
       "1        11  \n",
       "2        11  \n",
       "3        11  \n",
       "4        11  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_feature.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:32.871363Z",
     "start_time": "2020-03-21T08:00:32.755282Z"
    }
   },
   "outputs": [],
   "source": [
    "df_test = df_feature[df_feature['type'] == 'test'].copy()\n",
    "df_train = df_feature[df_feature['type'] == 'train'].copy()\n",
    "df_train = shuffle(df_train, random_state=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:55.351954Z",
     "start_time": "2020-03-21T08:00:32.872832Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Fold_1 Training ================================\n",
      "\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "Early stopping, best iteration is:\n",
      "[303]\ttrain's auc: 0.840105\tvalid's auc: 0.770636\n",
      "\n",
      "Fold_2 Training ================================\n",
      "\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "Early stopping, best iteration is:\n",
      "[310]\ttrain's auc: 0.840489\tvalid's auc: 0.770277\n",
      "\n",
      "Fold_3 Training ================================\n",
      "\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "Early stopping, best iteration is:\n",
      "[287]\ttrain's auc: 0.835554\tvalid's auc: 0.769252\n",
      "\n",
      "Fold_4 Training ================================\n",
      "\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "Early stopping, best iteration is:\n",
      "[346]\ttrain's auc: 0.846797\tvalid's auc: 0.77001\n",
      "\n",
      "Fold_5 Training ================================\n",
      "\n",
      "Training until validation scores don't improve for 50 rounds\n",
      "Early stopping, best iteration is:\n",
      "[244]\ttrain's auc: 0.827138\tvalid's auc: 0.768532\n"
     ]
    }
   ],
   "source": [
    "ycol = 'target'\n",
    "feature_names = list(\n",
    "    filter(lambda x: x not in [ycol, 'id', 'wave_index', 'tracking_id', 'expect_time', 'date', 'type', 'group',\n",
    "                               'courier_wave_start_lng', 'courier_wave_start_lat', 'shop_id', 'current_time_date'], df_train.columns))\n",
    "\n",
    "model = lgb.LGBMClassifier(num_leaves=64,\n",
    "                           max_depth=10,\n",
    "                           learning_rate=0.1,\n",
    "                           n_estimators=10000000,\n",
    "                           subsample=0.8,\n",
    "                           feature_fraction=0.8,\n",
    "                           reg_alpha=0.5,\n",
    "                           reg_lambda=0.5,\n",
    "                           random_state=seed,\n",
    "                           metric=None\n",
    "                           )\n",
    "\n",
    "\n",
    "oof = []\n",
    "prediction = df_test[['id', 'group']]\n",
    "prediction['target'] = 0\n",
    "df_importance_list = []\n",
    "\n",
    "kfold = GroupKFold(n_splits=5)\n",
    "for fold_id, (trn_idx, val_idx) in enumerate(kfold.split(df_train[feature_names], df_train[ycol], df_train['group'])):\n",
    "    X_train = df_train.iloc[trn_idx][feature_names]\n",
    "    Y_train = df_train.iloc[trn_idx][ycol]\n",
    "\n",
    "    X_val = df_train.iloc[val_idx][feature_names]\n",
    "    Y_val = df_train.iloc[val_idx][ycol]\n",
    "\n",
    "    print('\\nFold_{} Training ================================\\n'.format(fold_id+1))\n",
    "\n",
    "    lgb_model = model.fit(X_train,\n",
    "                          Y_train,\n",
    "                          eval_names=['train', 'valid'],\n",
    "                          eval_set=[(X_train, Y_train), (X_val, Y_val)],\n",
    "                          verbose=500,\n",
    "                          eval_metric='auc',\n",
    "                          early_stopping_rounds=50)\n",
    "\n",
    "    pred_val = lgb_model.predict_proba(\n",
    "        X_val, num_iteration=lgb_model.best_iteration_)[:, 1]\n",
    "    df_oof = df_train.iloc[val_idx][['id', 'group', ycol]].copy()\n",
    "    df_oof['pred'] = pred_val\n",
    "    oof.append(df_oof)\n",
    "\n",
    "    pred_test = lgb_model.predict_proba(\n",
    "        df_test[feature_names], num_iteration=lgb_model.best_iteration_)[:, 1]\n",
    "    prediction['target'] += pred_test / 5\n",
    "\n",
    "    df_importance = pd.DataFrame({\n",
    "        'column': feature_names,\n",
    "        'importance': lgb_model.feature_importances_,\n",
    "    })\n",
    "    df_importance_list.append(df_importance)\n",
    "\n",
    "    del lgb_model, pred_val, pred_test, X_train, Y_train, X_val, Y_val\n",
    "    gc.collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:55.362736Z",
     "start_time": "2020-03-21T08:00:55.352969Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>column</th>\n",
       "      <th>importance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>aoi_id</td>\n",
       "      <td>2507.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>grid_distance</td>\n",
       "      <td>2468.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>target_lat</td>\n",
       "      <td>1594.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>speed</td>\n",
       "      <td>1542.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>current_time</td>\n",
       "      <td>1442.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>target_lng</td>\n",
       "      <td>1412.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>source_lng</td>\n",
       "      <td>1337.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>courier_id</td>\n",
       "      <td>1279.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>source_lat</td>\n",
       "      <td>1206.8</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>promise_deliver_time</td>\n",
       "      <td>1047.2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>estimate_pick_time</td>\n",
       "      <td>854.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>last_tracking_id</td>\n",
       "      <td>734.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>max_load</td>\n",
       "      <td>532.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>action_type</td>\n",
       "      <td>315.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>last_action_type</td>\n",
       "      <td>240.4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>level</td>\n",
       "      <td>149.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>weather_grade</td>\n",
       "      <td>100.8</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  column  importance\n",
       "0                 aoi_id      2507.6\n",
       "1          grid_distance      2468.6\n",
       "2             target_lat      1594.6\n",
       "3                  speed      1542.8\n",
       "4           current_time      1442.2\n",
       "5             target_lng      1412.6\n",
       "6             source_lng      1337.8\n",
       "7             courier_id      1279.8\n",
       "8             source_lat      1206.8\n",
       "9   promise_deliver_time      1047.2\n",
       "10    estimate_pick_time       854.4\n",
       "11      last_tracking_id       734.4\n",
       "12              max_load       532.4\n",
       "13           action_type       315.6\n",
       "14      last_action_type       240.4\n",
       "15                 level       149.6\n",
       "16         weather_grade       100.8"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_importance = pd.concat(df_importance_list)\n",
    "df_importance = df_importance.groupby(['column'])['importance'].agg(\n",
    "    'mean').sort_values(ascending=False).reset_index()\n",
    "df_importance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:55.366921Z",
     "start_time": "2020-03-21T08:00:55.363747Z"
    }
   },
   "outputs": [],
   "source": [
    "def wave_label_func(group):\n",
    "    target_list = group['target'].values.tolist()\n",
    "    pred_list = group['pred'].values.tolist()\n",
    "    max_index = pred_list.index(max(pred_list))\n",
    "    if target_list[max_index] == 1:\n",
    "        return 1\n",
    "    else:\n",
    "        return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:57.675769Z",
     "start_time": "2020-03-21T08:00:55.367948Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc: 0.6243134050569703\n"
     ]
    }
   ],
   "source": [
    "    df_oof = pd.concat(oof)\n",
    "    df_temp = df_oof.groupby(['group']).apply(wave_label_func).reset_index()\n",
    "    df_temp.columns = ['group', 'label']\n",
    "    acc = df_temp[df_temp['label'] == 1].shape[0] / df_temp.shape[0]\n",
    "    print('acc:', acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-03-21T08:00:58.534429Z",
     "start_time": "2020-03-21T08:00:57.676985Z"
    }
   },
   "outputs": [],
   "source": [
    "def label_func(group):\n",
    "    group = group.values.tolist()\n",
    "    max_index = group.index(max(group))\n",
    "    label = np.zeros(len(group))\n",
    "    label[max_index] = 1\n",
    "    return label\n",
    "\n",
    "\n",
    "prediction['label'] = prediction.groupby(\n",
    "    ['group'])['target'].transform(label_func)\n",
    "sub_part1 = prediction[prediction['label'] == 1]\n",
    "df_oof = df_oof[df_oof['target'] == 1]\n",
    "next_action = pd.concat([df_oof[['id']], sub_part1[['id']]])\n",
    "next_action.to_csv('./temp/next_action.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dm] *",
   "language": "python",
   "name": "conda-env-dm-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
